/******************************************************************************
 * Copyright (c) 2011-2021, NVIDIA CORPORATION.  All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright
 *       notice, this list of conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright
 *       notice, this list of conditions and the following disclaimer in the
 *       documentation and/or other materials provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the
 *       names of its contributors may be used to endorse or promote products
 *       derived from this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
 * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 ******************************************************************************/

#pragma once

#include <cuda_fp16.h>

#include <cmath>

namespace fmha {

////////////////////////////////////////////////////////////////////////////////////////////////////

// Returns an approximation of e^(x - max), evaluated with the fast-math intrinsic __expf.
inline __device__ float apply_exp_(float x, float max) {
    const float shifted = x - max;
    return __expf(shifted);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// Returns 2^(x - max).
inline __device__ float apply_exp2_(float x, float max) {
    // With fast-math, exp2f produces the same PTX instruction as the hand-written
    // assembly kept here for reference:
    //     float diff = x - max;
    //     float res;
    //     asm ("ex2.approx.ftz.f32 %0, %1;\n\t" : "=f"(res) : "f"(diff));
    //     return res;
    const float shifted = x - max;
    return exp2f(shifted);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// Maps a column count to the element type exposed as ReadType<COLS>::T.
// The primary template deliberately has no member T, so any column count other than
// the specialized ones (4 and 8) fails to compile wherever ReadType<COLS>::T is named.
template <int COLS>
struct ReadType {};

// 4 columns: T is a single float.
template <>
struct ReadType<4> {
    typedef float T;
};

// 8 columns: T is a float2.
template <>
struct ReadType<8> {
    typedef float2 T;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename Cta_tile, typename Kernel_traits>
struct Smem_tile_reduce {
    // Helper class to distribute MMA tiles reduced over rows per warp over quads.
    // One lane per quad writes a float into a ROWS x COLS float tile; every lane then reads
    // the tile back through read_t-typed pointers.

    // The Mma tile.
    using Mma_tile = fmha::Hmma_tile<Cta_tile>;

    // The number of MMAs in M/N dimensions.
    static constexpr int MMAS_M = Mma_tile::MMAS_M;
    static constexpr int MMAS_N = Mma_tile::MMAS_N;

    // The number of warps in M/N dimensions.
    static constexpr int WARPS_M = Cta_tile::WARPS_M;
    static constexpr int WARPS_N = Cta_tile::WARPS_N;

    // Tile geometry: 16 rows per MMA per warp in M, one column per warp in N.
    static constexpr int ROWS = WARPS_M * MMAS_M * 16;
    static constexpr int COLS = WARPS_N;
    static_assert(COLS == 4 || COLS == 8);
    // Divisor applied to the quad index to obtain the XOR term of the column swizzle.
    static constexpr int ROWS_PER_XOR_PATTERN = (COLS == 8) ? 4 : 8;
    static constexpr int BYTES_PER_TILE = ROWS * COLS * sizeof(float);
    static constexpr int ELTS_PER_TILE = ROWS * COLS;

    static constexpr int THREADS_PER_GROUP = Kernel_traits::Gmem_tile_o::THREADS_PER_ROW;
    // TD [2022-05-02]: No longer true if head_dim != 64
    // static_assert(THREADS_PER_GROUP == 16); // DEBUG
    static constexpr int ROWS_PER_WARP = 32 / THREADS_PER_GROUP;
    static constexpr int LOOPS = Kernel_traits::Gmem_tile_o::LOOPS;
    static_assert(LOOPS == 1);

    // Type of one read: float for 4 columns, float2 for 8 columns.
    using read_t = typename ReadType<COLS>::T;

    // Ctor. Derives the per-thread write/read pointers into the tile at smem_ from tidx.
    __device__ inline Smem_tile_reduce(float *smem_, const int tidx) {
        const int lane_idx = tidx % 32;
        const int warp_idx = tidx / 32;

        // Position of the warp in the M/N warp grid.
        const int warp_row = warp_idx % WARPS_M;
        const int warp_col = warp_idx / WARPS_M;

        // Index of the quad inside the warp and of the lane inside its quad.
        const int quad_idx = lane_idx / 4;
        qid_ = lane_idx % 4;

        // Swizzle the column to avoid 2-fold bank conflicts when we have 8 warps.
        // This won't affect reading as we assume commutative reduction ops.
        const int swizzled_col = warp_col ^ (quad_idx / ROWS_PER_XOR_PATTERN);

        // First tile row owned by this warp.
        const int first_row = warp_row * 16 * MMAS_M;

        // Writes address individual floats: WARPS_N floats per row.
        smem_write_ = smem_ + ((first_row + quad_idx) * WARPS_N + swizzled_col);

        // Reads address read_t elements: 4 per row, one for each lane of the quad.
        read_t *const smem_as_read_t = reinterpret_cast<read_t *>(smem_);
        smem_read_row_ = smem_as_read_t + (first_row * 4 + qid_);
        smem_read_ = smem_read_row_ + quad_idx * 4;
    }

    // Writes the 2 * MMAS_M values of frag to the tile. Only the lane with qid_ == 0 of each
    // quad writes; for every MMA the two values land 8 rows apart.
    __device__ inline void store(float (&frag)[2 * MMAS_M]) {
        if (qid_ != 0) {
            return;
        }
        constexpr int kEightRows = 8 * WARPS_N;
#pragma unroll
        for (int mi = 0; mi < MMAS_M; mi++) {
            float *const dst = smem_write_ + mi * 16 * WARPS_N;
            dst[0] = frag[2 * mi];
            dst[kEightRows] = frag[2 * mi + 1];
        }
    }

    // Reads, for every MMA, the entries of the two rows this thread's quad maps to
    // (8 rows apart) into frag.
    __device__ inline void load(read_t (&frag)[2 * MMAS_M]) {
#pragma unroll
        for (int mi = 0; mi < MMAS_M; mi++) {
            const read_t *const src = smem_read_ + mi * 16 * 4;
            frag[2 * mi] = src[0];
            frag[2 * mi + 1] = src[8 * 4];
        }
    }

    // Reads, for every MMA, the entry of the explicitly given row (relative to the first row
    // of this thread's warp within that MMA) into frag.
    __device__ inline void load_row(read_t (&frag)[MMAS_M], int row) {
        const read_t *const src = smem_read_row_ + row * 4;
#pragma unroll
        for (int mi = 0; mi < MMAS_M; mi++) {
            frag[mi] = src[mi * 16 * 4];
        }
    }

    // Index of the lane inside its quad (lane % 4).
    int qid_;
    // Per-thread write pointer (float granularity).
    float *smem_write_;
    // Per-thread read pointer for load() (read_t granularity, includes the quad's row).
    read_t *smem_read_;
    // Per-thread read pointer for load_row() (read_t granularity, row supplied by the caller).
    read_t *smem_read_row_;
};

// Per-thread storage and element-wise operations shared by the softmax implementations.
// Each thread holds a [MMAS_M * 2][MMAS_N * 4] array of floats (elt_) on which masking,
// exponentiation and scaling are applied in place.
template <typename Cta_tile, typename Kernel_traits>
struct Softmax_base {
    // The Mma tile.
    using Mma_tile = fmha::Hmma_tile<Cta_tile>;

    // The number of MMAs in M/N dimensions.
    static constexpr int MMAS_M = Mma_tile::MMAS_M;
    static constexpr int MMAS_N = Mma_tile::MMAS_N;

    // The number of groups of warp such that we have at most 4 warps writing consecutive elements.
    static constexpr int GROUPS = fmha::DivUpConstexpr(Cta_tile::WARPS_N, 4);
    // The number of elements that we are going to store per row.
    static constexpr int ELEMENTS_PER_ROW = Cta_tile::WARPS_N / GROUPS;
    // The number of rows.
    static constexpr int ROWS = Cta_tile::M * GROUPS;
    // The total number of elements.
    static constexpr int ELEMENTS = ROWS * ELEMENTS_PER_ROW;

    // Ctor.
    //   params: kernel parameters; not read by the code currently enabled in this constructor.
    //   smem:   shared-memory buffer, interpreted as an array of floats.
    //   tidx:   index of the calling thread.
    template <typename Params>
    inline __device__ Softmax_base(const Params &params, void *smem, int tidx)
        :  // The packed mask is not used at the moment; keep the pointer in a defined state.
           // Previously: packed_mask_ptr_(reinterpret_cast<const char*>(params.packed_mask_ptr)),
          packed_mask_ptr_(nullptr),
          smem_(reinterpret_cast<float *>(smem)),
          tidx_(tidx) {
        // Move to the 1st mask loaded by the thread+ tidx;
        // packed_mask_ptr_ += bidb * params.packed_mask_stride_in_bytes + tidx * sizeof(uint32_t);

        // Extract the position in the warp.
        int warp = tidx / Cta_tile::THREADS_PER_WARP;
        int lane = tidx % Cta_tile::THREADS_PER_WARP;

        // Decompose the warp index into M and N.
        int warp_m = warp % Cta_tile::WARPS_M;
        int warp_n = warp / Cta_tile::WARPS_M;

        // Decompose the warp-n index into group/position-inside-the-group.
        int warp_g = warp_n / ELEMENTS_PER_ROW;
        int warp_i = warp_n % ELEMENTS_PER_ROW;

        // The location written by the threads.
        int write_row = warp_g * (ROWS / GROUPS) + warp_m * Mma_tile::M_PER_MMA + lane / 4;
        int write_col = warp_i;

        // Assemble the write pointer.
        smem_write_ = &smem_[write_row * ELEMENTS_PER_ROW + write_col];

        // Assemble the read pointer.
        smem_read_ = &smem_[warp_m * Mma_tile::M_PER_MMA + lane / 4];
    }

    // Overwrite every element for which mask.is_valid(mi, ni, ii, jj) is false.
    // The replacement value is 0 when `zero` is true and -INFINITY otherwise.
    template <bool zero = false, typename Mask>
    inline __device__ void apply_mask(const Mask &mask) {
#pragma unroll
        for (int mi = 0; mi < MMAS_M; ++mi) {
#pragma unroll
            for (int ii = 0; ii < 2; ++ii) {
#pragma unroll
                for (int ni = 0; ni < MMAS_N; ++ni) {
#pragma unroll
                    for (int jj = 0; jj < 4; ++jj) {
                        if (!mask.is_valid(mi, ni, ii, jj)) {
                            elt_[2 * mi + ii][4 * ni + jj] = zero ? 0.f : -INFINITY;
                        }
                    }
                }
            }
        }
    }

    // Apply the exp to all the elements: elt_[mi][ni] <- exp(elt_[mi][ni] - max[mi]).
    //   max_in_base2: max[] has already been multiplied by log_2(e).
    //   elt_in_base2: the elements have already been multiplied by log_2(e).
    template <bool max_in_base2 = false, bool elt_in_base2 = false>
    inline __device__ void apply_exp(const float (&max)[MMAS_M * 2]) {
#pragma unroll
        for (int mi = 0; mi < MMAS_M * 2; ++mi) {
            // Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
            // max * log_2(e)) This allows the compiler to use the ffma
            // instruction instead of fadd and fmul separately.
            constexpr float kLog2e = M_LOG2E;
            const float max_base2 = max_in_base2 ? max[mi] : max[mi] * kLog2e;
#pragma unroll
            for (int ni = 0; ni < MMAS_N * 4; ++ni) {
                // elt_[mi][ni] = apply_exp_(elt_[mi][ni], max[mi]);
                elt_[mi][ni] = apply_exp2_(elt_in_base2 ? elt_[mi][ni] : elt_[mi][ni] * kLog2e, max_base2);
            }
        }
    }

    // Scale the elements and apply the exp:
    //   elt_[mi][ni] <- exp2(elt_[mi][ni] * scale_ * log_2(e) - max[mi] * m)
    // where m is scale_ * log_2(e) when scale_max is true and log_2(e) otherwise.
    template <bool scale_max = true>
    inline __device__ void scale_apply_exp(const float (&max)[MMAS_M * 2], const float scale_) {
        // Both factors are computed once per call, before the loops.
        const float max_scale = scale_max ? scale_ * M_LOG2E : M_LOG2E;
        const float scale = scale_ * M_LOG2E;
#pragma unroll
        for (int mi = 0; mi < MMAS_M * 2; ++mi) {
            // Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
            // max * log_2(e)) This allows the compiler to use the ffma
            // instruction instead of fadd and fmul separately.
            const float max_scaled = max[mi] * max_scale;
#pragma unroll
            for (int ni = 0; ni < MMAS_N * 4; ++ni) {
                elt_[mi][ni] = apply_exp2_(elt_[mi][ni] * scale, max_scaled);
            }
        }
    }

    // Apply the exp to all the elements using one max per column:
    //   elt_[mi][ni] <- exp(elt_[mi][ni] - max[ni]).
    //   max_in_base2: max[] has already been multiplied by log_2(e).
    template <bool max_in_base2 = false>
    inline __device__ void apply_exp_col(const float (&max)[MMAS_N * 4]) {
#pragma unroll
        for (int ni = 0; ni < MMAS_N * 4; ++ni) {
            constexpr float kLog2e = M_LOG2E;
            const float max_base2 = max_in_base2 ? max[ni] : max[ni] * kLog2e;
#pragma unroll
            for (int mi = 0; mi < MMAS_M * 2; ++mi) {
                elt_[mi][ni] = apply_exp2_(elt_[mi][ni] * kLog2e, max_base2);
            }
        }
    }
    // inline __device__ void apply_exp_col(const float (&max)[MMAS_N]) {
    //     constexpr float kLog2e = M_LOG2E;
    //     #pragma unroll
    //     for( int ni = 0; ni < MMAS_N * 4; ++ni ) {
    //         float max_base2 = max_in_base2 ? max[ni / 4] : max[ni / 4] * kLog2e;
    //         max_base2 = __shfl_sync(0xffffffff, max_base2, (ni % 4) * 8 + threadIdx.x % 8);
    //         #pragma unroll
    //         for( int mi = 0; mi < MMAS_M * 2; ++mi ) {
    //             elt_[mi][ni] = apply_exp2_(elt_[mi][ni] * kLog2e, max_base2);
    //         }
    //     }
    // }

    // Scale all the elements: each row mi is multiplied by 1 / sum[mi].
    // Rows whose sum is 0 or NaN are multiplied by 1, i.e. left unchanged.
    inline __device__ void scale(const float (&sum)[MMAS_M * 2]) {
        // Precompute the inverse sum to normalize. Without -use_fast_math, it makes a huge deal.
        float inv_sum[MMAS_M * 2];
#pragma unroll
        for (int mi = 0; mi < MMAS_M * 2; ++mi) {
            // sum != sum is true only for NaN.
            inv_sum[mi] = (sum[mi] == 0.f || sum[mi] != sum[mi]) ? 1.f : 1.f / sum[mi];
        }

        // Update the values.
#pragma unroll
        for (int mi = 0; mi < MMAS_M * 2; ++mi) {
#pragma unroll
            for (int ni = 0; ni < MMAS_N * 4; ++ni) {
                elt_[mi][ni] *= inv_sum[mi];
            }
        }
    }

    // Subtract dp_sum[mi] from every element of row mi.
    inline __device__ void subtract_dp_sum(const float (&dp_sum)[MMAS_M * 2]) {
#pragma unroll
        for (int mi = 0; mi < MMAS_M * 2; ++mi) {
#pragma unroll
            for (int ni = 0; ni < MMAS_N * 4; ++ni) {
                elt_[mi][ni] -= dp_sum[mi];
            }
        }
    }

    // The pointer to the mask. Currently unused and always null.
    const char *packed_mask_ptr_;
    // Shared memory for the CTA-wide reduction.
    float *smem_, *smem_write_, *smem_read_;
    // The current thread index.
    int tidx_;
    // The elements.
    float elt_[MMAS_M * 2][MMAS_N * 4];
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// Softmax over the accumulators of the first GEMM. On top of the element-wise operations of
// Softmax_base it converts between MMA fragments and the per-thread element array, and
// performs row-wise max/sum reductions staged through two shared-memory tiles.
template <typename Cta_tile, typename Kernel_traits>
struct Softmax : public Softmax_base<Cta_tile, Kernel_traits> {
    // The base class.
    using Base = Softmax_base<Cta_tile, Kernel_traits>;
    // The fragment.
    using Fragment_a = fmha::Fragment_a<fmha::Row>;

    static_assert(Fragment_a::NUM_REGS == 4);

    static constexpr int WARPS_M = Cta_tile::WARPS_M;
    static constexpr int WARPS_N = Cta_tile::WARPS_N;
    // The MMAs.
    static constexpr int MMAS_M = Base::MMAS_M;
    static constexpr int MMAS_N = Base::MMAS_N;

    // The accumulators.
    using Accumulator = fmha::Fragment_accumulator;
    using Accumulator_out = Fragment<uint16_t, 8>;
    static_assert(Accumulator_out::NUM_REGS == 4);

    static_assert(std::is_same<Accumulator::Data_type, float>::value);

    using Smem_tile_red = Smem_tile_reduce<Cta_tile, Kernel_traits>;
    static_assert(Smem_tile_red::ELTS_PER_TILE == Cta_tile::M * WARPS_N);

    // Ctor.
    //   params: kernel parameters; params.scale_bmm1 is stored as raw 32 bits.
    //   smem:   shared-memory buffer. The sum tile starts at smem and the max tile follows
    //           it, Smem_tile_red::ELTS_PER_TILE floats later. The base class receives the
    //           same pointer.
    //   tidx:   index of the calling thread.
    // The initializers are listed in member declaration order, which is the order in which
    // C++ runs them.
    template <typename Params>
    inline __device__ Softmax(const Params &params, void *smem, int tidx)
        : Base(params, smem, tidx),
          params_scale_bmm1_(params.scale_bmm1),
          smem_max_(static_cast<float *>(smem) + Smem_tile_red::ELTS_PER_TILE, tidx),
          smem_sum_(static_cast<float *>(smem), tidx) {}

    // Pack the data to a fragment for the next GEMM.
    // For each (ki, mi), the 2x4 block of elements is converted pairwise to elem_type:
    // registers 0 and 2 hold row 0 (columns 0-1 and 2-3), registers 1 and 3 hold row 1.
    template <typename elem_type = __half, int K, int M>
    inline __device__ void pack(Fragment_a (&dst)[K][M]) const {
#pragma unroll
        for (int mi = 0; mi < M; ++mi) {
#pragma unroll
            for (int ki = 0; ki < K; ++ki) {
                // 1st row - 4 elements per row.
                float tmp_00 = this->elt_[2 * mi + 0][4 * ki + 0];
                float tmp_01 = this->elt_[2 * mi + 0][4 * ki + 1];
                float tmp_02 = this->elt_[2 * mi + 0][4 * ki + 2];
                float tmp_03 = this->elt_[2 * mi + 0][4 * ki + 3];

                // 2nd row - 4 elements per row.
                float tmp_10 = this->elt_[2 * mi + 1][4 * ki + 0];
                float tmp_11 = this->elt_[2 * mi + 1][4 * ki + 1];
                float tmp_12 = this->elt_[2 * mi + 1][4 * ki + 2];
                float tmp_13 = this->elt_[2 * mi + 1][4 * ki + 3];

                // Pack to 4 registers.
                dst[ki][mi].reg(0) = fmha::float2_pack<elem_type>(tmp_00, tmp_01);
                dst[ki][mi].reg(1) = fmha::float2_pack<elem_type>(tmp_10, tmp_11);
                dst[ki][mi].reg(2) = fmha::float2_pack<elem_type>(tmp_02, tmp_03);
                dst[ki][mi].reg(3) = fmha::float2_pack<elem_type>(tmp_12, tmp_13);
            }
        }
    }

    // Copy the FP32 accumulators into the element array, multiplying each by scale_bmm1.
    // Accumulator elements 0, 1, 4, 5 go to the 1st row and 2, 3, 6, 7 to the 2nd row.
    inline __device__ void unpack(const Accumulator (&acc)[MMAS_M][MMAS_N]) {
        // The scale is stored as the raw bits of a float; reinterpret them with the bit-cast
        // intrinsic instead of aliasing the integer through a float reference.
        const float scalef = __uint_as_float(this->params_scale_bmm1_);

#pragma unroll
        for (int mi = 0; mi < MMAS_M; ++mi) {
#pragma unroll
            for (int ni = 0; ni < MMAS_N; ++ni) {
                // 1st row - 4 elements per row.
                this->elt_[2 * mi + 0][4 * ni + 0] = acc[mi][ni].elt(0) * scalef;
                this->elt_[2 * mi + 0][4 * ni + 1] = acc[mi][ni].elt(1) * scalef;
                this->elt_[2 * mi + 0][4 * ni + 2] = acc[mi][ni].elt(4) * scalef;
                this->elt_[2 * mi + 0][4 * ni + 3] = acc[mi][ni].elt(5) * scalef;
                // 2nd row - 4 elements per row.
                this->elt_[2 * mi + 1][4 * ni + 0] = acc[mi][ni].elt(2) * scalef;
                this->elt_[2 * mi + 1][4 * ni + 1] = acc[mi][ni].elt(3) * scalef;
                this->elt_[2 * mi + 1][4 * ni + 2] = acc[mi][ni].elt(6) * scalef;
                this->elt_[2 * mi + 1][4 * ni + 3] = acc[mi][ni].elt(7) * scalef;
            }
        }
    }

    // Copy the FP32 accumulators into the element array without any scaling.
    // Uses the same element-to-row mapping as unpack().
    inline __device__ void unpack_noscale(const Accumulator (&acc)[MMAS_M][MMAS_N]) {
#pragma unroll
        for (int mi = 0; mi < MMAS_M; ++mi) {
#pragma unroll
            for (int ni = 0; ni < MMAS_N; ++ni) {
                // 1st row - 4 elements per row.
                this->elt_[2 * mi + 0][4 * ni + 0] = acc[mi][ni].elt(0);
                this->elt_[2 * mi + 0][4 * ni + 1] = acc[mi][ni].elt(1);
                this->elt_[2 * mi + 0][4 * ni + 2] = acc[mi][ni].elt(4);
                this->elt_[2 * mi + 0][4 * ni + 3] = acc[mi][ni].elt(5);
                // 2nd row - 4 elements per row.
                this->elt_[2 * mi + 1][4 * ni + 0] = acc[mi][ni].elt(2);
                this->elt_[2 * mi + 1][4 * ni + 1] = acc[mi][ni].elt(3);
                this->elt_[2 * mi + 1][4 * ni + 2] = acc[mi][ni].elt(6);
                this->elt_[2 * mi + 1][4 * ni + 3] = acc[mi][ni].elt(7);
            }
        }
    }

    // Reduce each row of the calling thread's elements with op.
    // With zero_init the result replaces frag[mi]; otherwise the previous value of frag[mi]
    // takes part in the reduction as well.
    template <bool zero_init = true, typename Operator>
    __device__ inline void thread_reduce_(float (&frag)[2 * MMAS_M], Operator &op) {
#pragma unroll
        for (int mi = 0; mi < 2 * MMAS_M; mi++) {
            frag[mi] = zero_init ? this->elt_[mi][0] : op(frag[mi], this->elt_[mi][0]);
#pragma unroll
            for (int ni = 1; ni < 4 * MMAS_N; ni++) {
                frag[mi] = op(frag[mi], this->elt_[mi][ni]);
            }
        }
    }

    // Full row reduction: per-thread reduce, quad_reduce, store to smem_red, barrier, load
    // from smem_red and quad_allreduce into frag.
    // Contains a __syncthreads(), so every thread of the block must call it.
    // NOTE(review): there is no barrier after the load; presumably callers ensure smem_red is
    // not overwritten before all threads have read it - confirm against the call sites.
    template <bool zero_init = true, typename Operator>
    __device__ inline void reduce_(float (&frag)[2 * MMAS_M], Operator &op, Smem_tile_red &smem_red) {
        thread_reduce_<zero_init>(frag, op);
        quad_reduce(frag, frag, op);
        smem_red.store(frag);
        __syncthreads();
        typename Smem_tile_red::read_t tmp[2 * MMAS_M];
        smem_red.load(tmp);
        quad_allreduce(frag, tmp, op);
    }

    // Row-wise max reduction, staged through the max tile.
    template <bool zero_init = true>
    __device__ inline void reduce_max(float (&frag)[2 * MMAS_M]) {
        MaxOp<float> max;
        reduce_<zero_init>(frag, max, smem_max_);
    }

    // Row-wise sum reduction, staged through the sum tile.
    __device__ inline void reduce_sum(float (&frag)[2 * MMAS_M]) {
        SumOp<float> sum;
        reduce_(frag, sum, smem_sum_);
    }

    // First half of a sum reduction: per-thread reduce, quad_reduce and store to the sum tile.
    // No barrier is issued here; the stored values are read back by reduce_sum_after_sync_.
    template <bool zero_init = true>
    __device__ inline void reduce_sum_before_sync_(float (&frag)[2 * MMAS_M]) {
        SumOp<float> sum;
        thread_reduce_<zero_init>(frag, sum);
        quad_reduce(frag, frag, sum);
        smem_sum_.store(frag);
    }

    // Second half of a reduction: for each of the NROWS requested rows, load the row from
    // smem_red and quad_allreduce it into frag[ii].
    template <int NROWS, typename Operator>
    __device__ inline void reduce_after_sync_(float (&frag)[NROWS][MMAS_M], const int (&rows)[NROWS], Operator &op,
                                              Smem_tile_red &smem_red) {
#pragma unroll
        for (int ii = 0; ii < NROWS; ii++) {
            typename Smem_tile_red::read_t tmp[MMAS_M];
            smem_red.load_row(tmp, rows[ii]);
            quad_allreduce(frag[ii], tmp, op);
        }
    }

    // reduce_after_sync_ with a sum, reading from the sum tile.
    template <int NROWS>
    __device__ inline void reduce_sum_after_sync_(float (&frag)[NROWS][MMAS_M], const int (&rows)[NROWS]) {
        SumOp<float> sum;
        reduce_after_sync_(frag, rows, sum, smem_sum_);
    }

    // reduce_after_sync_ with a max, reading from the max tile.
    template <int NROWS>
    __device__ inline void reduce_max_after_sync_(float (&frag)[NROWS][MMAS_M], const int (&rows)[NROWS]) {
        MaxOp<float> max;
        reduce_after_sync_(frag, rows, max, smem_max_);
    }

    // The scale applied by unpack(), stored as the raw bits of a float.
    const uint32_t params_scale_bmm1_;
    // Shared-memory tile used by the max reductions.
    Smem_tile_red smem_max_;
    // Shared-memory tile used by the sum reductions.
    Smem_tile_red smem_sum_;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace fmha
